{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6.5 卷积神经网络迁移学习\n",
    "### 6.5.1 迁移学习介绍\n",
    "在6.4节中介绍了1998年提出的LeNet5模型和2015年提出的Inception-v3模型。对别二者可以发现CNN的层数和复杂度都发生了巨大的变化，下表给出了2012-2015年ILSVRC第一名模型的层数（在这里层数只计算了卷积层和全连接层的个数，没有参数的池化层没有包括在内）及top5的错误率。\n",
    "<p align='center'>\n",
    "    <img src=images/表6.2.JPG>\n",
    "</p>\n",
    "\n",
    "从上表可以看到，随着模型层数及复杂度的增加，模型在ImageNet上的错误率也随之降低。然而，训练、复杂的卷积神经网络需要非常多的标注数据。如6.1节中提到的，ImageNet图像分类数据集中有120万标注图片，所以才能将152层的ResNet的模型训练到大约96.5% 的正确率。在其实的应用中，很难收集到如此多的标注数据。即使可以收集到，也需要花费大量人力物力。而且即使有海量的训练数据，要训练一个复杂的卷积神经网络也需要几天甚至几周的时间。为了解决标注数据和训练时间的问题，可以使用本节将要介绍的迁移学习。\n",
    "\n",
    "**所谓迁移学习，就是将一个问题上训练好的模型通过简单的调整使其适用于一个新的问题**。本节将介绍如何利用ImageNet数据集上训练好的Inception-v3 模型来解决一个新的图像分类问题。根据论文[DeCAF: A Deep Convolutional Activation Feature for Generic Visual Recognition](http://proceedings.mlr.press/v32/donahue14.pdf)中的结论：`可以保留训练好Inception-v3模型中所有卷积层的参数，只是替换最后一层全连接层`。在最后这一层全连接层之前的网络层称之为瓶颈层（bottleneck）。\n",
    "\n",
    "- 原理：**将新的图像通过训练好的卷积神经网络直到瓶颈层的过程可以看成是对图像进行特征提取的过程**。在训练好的Inception-v3模型中，因为将瓶颈层的输出再通过一个单层的全连接层神经网络可以很好地区分1000种类别的图像，所以有理由认为**瓶颈层输出的节点向量可以被作为任何图像的一个更加精简且表达能力更强的特征向量。**于是，在新数据集上，可以直接利用这个训练好的神经网络对图像进行特征提取，然后再将提取得到的特征向量作为输入来训练一个新的单层全连接神经网络处理新的分类问题。\n",
    "\n",
    "- 适用条件：一般来说，在数据量足够的情况下，迁移学习的效果不如完全重新训练。但是迁移学习所需要的训练时间和训练样本数要远远小于训练完整的模型。在没有GPU的普通台式机或者笔记本电脑上，6.5.2节中给出的TensorFlow训练过程只需要大约3小时，而且可以达到大概90%的正确率。\n",
    "\n",
    "**TensorFlow实现迁移学习实例**\n",
    "\n",
    "这里使用到的数据集是[http://download.tensorflow.org/example_images/flower_photos.tgz](http://download.tensorflow.org/example_images/flower_photos.tgz)，解压后的文件夹包含5个子文件夹，每一个子文件夹的名称为一种花的名称，代表了不同的类别。平均每种花有734张图片，每一张图片都是RGB色彩模式的，大小也不相同。和之前的样例不同，在这一节中给出的程序将直接处理没有整理过的图像数据。\n",
    "\n",
    "**1. 数据处理**："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "processing: daisy. total 300 images.\n",
      "100 images processed.\n",
      "200 images processed.\n",
      "300 images processed.\n",
      "processing: dandelion. total 300 images.\n",
      "100 images processed.\n",
      "200 images processed.\n",
      "300 images processed.\n",
      "processing: roses. total 300 images.\n",
      "100 images processed.\n",
      "200 images processed.\n",
      "300 images processed.\n",
      "processing: sunflowers. total 300 images.\n",
      "100 images processed.\n",
      "200 images processed.\n",
      "300 images processed.\n",
      "processing: tulips. total 300 images.\n",
      "100 images processed.\n",
      "200 images processed.\n",
      "300 images processed.\n"
     ]
    }
   ],
   "source": [
    "import glob\n",
    "import os.path\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.platform import gfile\n",
    "\n",
    "# 原始输入数据的目录，这个目录下有5个子目录，每个子目录底下保存这属于该类别的所有图片。\n",
    "INPUT_DATA = '../../datasets/flower_photos'\n",
    "# 输出文件地址。我们将整理后的图片数据通过numpy的格式保存。\n",
    "OUTPUT_FILE = '../../datasets/flower_processed_data.npy'\n",
    "\n",
    "# 测试数据和验证数据比例。\n",
    "VALIDATION_PERCENTAGE = 10\n",
    "TEST_PERCENTAGE = 10\n",
    "\n",
    "\n",
    "# 读取数据并将数据分割成训练数据、验证数据和测试数据。\n",
    "def create_image_lists(sess, testing_percentage, validation_percentage):\n",
    "    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]\n",
    "    is_root_dir = True\n",
    "    \n",
    "    # 初始化各个数据集。\n",
    "    training_images = []\n",
    "    training_labels = []\n",
    "    validation_images = []\n",
    "    validation_labels = []\n",
    "    testing_images = []\n",
    "    testing_labels = []\n",
    "    current_label = 0\n",
    "    \n",
    "    # 读取所有的子目录。\n",
    "    for sub_dir in sub_dirs:\n",
    "        if is_root_dir:        # sub_dirs第一个元素为根目录路径\n",
    "            is_root_dir = False\n",
    "            continue\n",
    "\n",
    "        # 获取一个子目录中所有的图片文件。\n",
    "        extensions = ['jpg']\n",
    "        file_list = []\n",
    "        dir_name = os.path.basename(sub_dir)    # 返回path最后的文件名，/或\\结尾则返回空值\n",
    "        for extension in extensions:\n",
    "            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)     #  ../../../datasets/flower_photos\\daisy\\*.jpg\n",
    "            file_list.extend(glob.glob(file_glob))    # 匹配所有符合条件文件，并以list返回\n",
    "        if not file_list: continue\n",
    "        print(\"processing: %s. total %d images.\"% (dir_name, len(file_list)))\n",
    "        \n",
    "        i = 0\n",
    "        # 处理图片数据。\n",
    "        for file_name in file_list:\n",
    "            i += 1\n",
    "            # 读取并解析图片，将图片转化为299*299以方便inception-v3模型来处理。\n",
    "            image_raw_data = gfile.FastGFile(file_name, 'rb').read()\n",
    "            image = tf.image.decode_jpeg(image_raw_data)\n",
    "            if image.dtype != tf.int32:\n",
    "                image = tf.image.convert_image_dtype(image, dtype=tf.int32)\n",
    "            image = tf.image.resize_images(image, [299, 299])\n",
    "            image_value = sess.run(image)\n",
    "            \n",
    "            # 随机划分数据集。\n",
    "            chance = np.random.randint(100)\n",
    "            if chance < validation_percentage:\n",
    "                validation_images.append(image_value)\n",
    "                validation_labels.append(current_label)\n",
    "            elif chance < (testing_percentage + validation_percentage):\n",
    "                testing_images.append(image_value)\n",
    "                testing_labels.append(current_label)\n",
    "            else:\n",
    "                training_images.append(image_value)\n",
    "                training_labels.append(current_label)\n",
    "            if i % 100 == 0:\n",
    "                print(i, \"images processed.\")\n",
    "        current_label += 1\n",
    "    \n",
    "    # 将训练数据随机打乱以获得更好的训练效果。\n",
    "    state = np.random.get_state()\n",
    "    np.random.shuffle(training_images)\n",
    "    np.random.set_state(state)\n",
    "    np.random.shuffle(training_labels)\n",
    "    \n",
    "    return np.asarray([training_images, training_labels,\n",
    "                       validation_images, validation_labels,\n",
    "                       testing_images, testing_labels])\n",
    "\n",
    "\n",
    "# 运行数据处理过程\n",
    "with tf.Session() as sess:\n",
    "    processed_data = create_image_lists(sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)\n",
    "    # 通过numpy格式保存处理后的数据。\n",
    "    np.save(OUTPUT_FILE, processed_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**2. 迁移学习：**\n",
    "\n",
    "运行以上代码可以将所有的图片数据划分为训练集、验证集和测试集，并且将原始jpg格式转化为inception-v3模型需要的299\\*299\\*3的数字矩阵（*由于本人太穷，原数据集全部使用电脑会MemError，这里取每类图片后300张作为运行数据集*）。\n",
    "\n",
    "谷歌提供的训练好的inception-v3模型地址如下：[http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz](http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)。模型源代码[在这](https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_v3.py)。\n",
    "\n",
    "当新的数据集和预训练好的模型准备好之后，就可以通过下面代码来完成迁移学习了："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1231 training examples, 128 validation examples and 141 testing examples.\n",
      "Loading tuned variables from ../../../datasets/inception_v3.ckpt\n",
      "INFO:tensorflow:Restoring parameters from ../../../datasets/inception_v3.ckpt\n",
      "Step 0: Training loss is 1.9 Validation accuracy = 17.2%\n",
      "Step 30: Training loss is 1.7 Validation accuracy = 21.9%\n",
      "Step 60: Training loss is 1.6 Validation accuracy = 49.2%\n",
      "Step 90: Training loss is 1.2 Validation accuracy = 79.7%\n",
      "Step 120: Training loss is 0.6 Validation accuracy = 89.1%\n",
      "Step 150: Training loss is 0.3 Validation accuracy = 93.0%\n",
      "Step 180: Training loss is 0.2 Validation accuracy = 93.0%\n",
      "Step 210: Training loss is 0.2 Validation accuracy = 93.0%\n",
      "Step 240: Training loss is 0.2 Validation accuracy = 92.2%\n",
      "Step 270: Training loss is 0.2 Validation accuracy = 92.2%\n",
      "Step 299: Training loss is 0.2 Validation accuracy = 92.2%\n",
      "Final test accuracy = 91.5%\n"
     ]
    }
   ],
   "source": [
    "import glob\n",
    "import os.path\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.platform import gfile\n",
    "import tensorflow.contrib.slim as slim\n",
    "# 加载通过TensorFlow-Slim定义好的inception_v3模型。\n",
    "import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3\n",
    "\n",
    "# 1. 定义训练过程中将要使用到的常量。\n",
    "# 处理好之后的数据文件。\n",
    "INPUT_DATA = '../../datasets/flower_processed_data.npy'\n",
    "# 保存训练好的模型的路径。这里可以将使用新数据训练得到的模型保存下来，如果计算充足，\n",
    "# 还可以在训练完最后的全连接层之后再训练所有的网络层，使模型更贴近新数据。\n",
    "TRAIN_FILE = 'train_dir/model'\n",
    "# 谷歌提供的训练好的模型文件地址。因为GitHub无法保存大于100M的文件，所以\n",
    "# 在运行时需要先自行从Google下载inception_v3.ckpt文件。\n",
    "CKPT_FILE = '../../datasets/inception_v3.ckpt'\n",
    "\n",
    "# 定义训练中使用的参数。\n",
    "LEARNING_RATE = 0.0001\n",
    "STEPS = 300\n",
    "BATCH = 32\n",
    "N_CLASSES = 5\n",
    "\n",
    "# 不需要从谷歌训练好的模型中加载的参数。这里就是最后的全连接层。这里给出的是参数的前缀\n",
    "CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'\n",
    "# 需要训练的网络层参数明层，在fine-tuning的过程中就是最后的全联接层。这里给出的是参数的前缀\n",
    "TRAINABLE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogit'\n",
    "\n",
    "\n",
    "# 2. 获取所有需要从谷歌训练好的模型中加载的参数。\n",
    "def get_tuned_variables():\n",
    "    exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(',')]\n",
    "    variables_to_restore = []\n",
    "    \n",
    "    # 枚举inception-v3模型中所有的参数，然后判断是否需要从加载列表中移除。\n",
    "    for var in slim.get_model_variables():\n",
    "        excluded = False\n",
    "        for exclusion in exclusions:\n",
    "            if var.op.name.startswith(exclusion):\n",
    "                excluded = True\n",
    "                break\n",
    "        if not excluded:\n",
    "            variables_to_restore.append(var)\n",
    "    return variables_to_restore\n",
    "\n",
    "\n",
    "# 3. 获取所有需要训练的变量列表。\n",
    "def get_trainable_variables():    \n",
    "    scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(',')]\n",
    "    variables_to_train = []\n",
    "    \n",
    "    # 枚举所有需要训练的参数前缀，并通过这些前缀找到所有需要训练的参数。\n",
    "    for scope in scopes:\n",
    "        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)\n",
    "        variables_to_train.extend(variables)\n",
    "    return variables_to_train\n",
    "\n",
    "\n",
    "# 4. 定义训练过程\n",
    "def main():\n",
    "    # 0. 加载预处理好的数据。\n",
    "    processed_data = np.load(INPUT_DATA)\n",
    "    training_images = processed_data[0]\n",
    "    n_training_example = len(training_images)\n",
    "    training_labels = processed_data[1]\n",
    "    \n",
    "    validation_images = processed_data[2]\n",
    "    validation_labels = processed_data[3]\n",
    "    \n",
    "    testing_images = processed_data[4]\n",
    "    testing_labels = processed_data[5]\n",
    "    print(\"%d training examples, %d validation examples and %d testing examples.\" % (\n",
    "        n_training_example, len(validation_labels), len(testing_labels)))\n",
    "\n",
    "    # 1. 定义inception-v3的输入输出\n",
    "    images = tf.placeholder(tf.float32, [None, 299, 299, 3], name='input_images')\n",
    "    labels = tf.placeholder(tf.int64, [None], name='labels')\n",
    "    \n",
    "    # 2. 定义前向传播、损失函数、反向传播\n",
    "    # 定义inception-v3模型。因为谷歌给出的只有模型参数取值，所以这里\n",
    "    # 需要在这个代码中定义inception-v3的模型结构。虽然理论上需要区分训练和\n",
    "    # 测试中使用到的模型，也就是说在测试时应该使用is_training=False，但是\n",
    "    # 因为预先训练好的inception-v3模型中使用的batch normalization参数与\n",
    "    # 新的数据会有出入，所以这里直接使用同一个模型来做测试。\n",
    "    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):\n",
    "        logits, _ = inception_v3.inception_v3(\n",
    "            images, num_classes=N_CLASSES, is_training=True)\n",
    "    \n",
    "    # 定义损失函数\n",
    "    trainable_variables = get_trainable_variables()    # 自定义函数，获取需要训练的变量\n",
    "    tf.losses.softmax_cross_entropy(\n",
    "        tf.one_hot(labels, N_CLASSES), logits, weights=1.0)\n",
    "    total_loss = tf.losses.get_total_loss()\n",
    "    \n",
    "    # 定义反向传播\n",
    "    train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(total_loss)\n",
    "    \n",
    "    # 计算正确率。\n",
    "    with tf.name_scope('evaluation'):\n",
    "        correct_prediction = tf.equal(tf.argmax(logits, 1), labels)\n",
    "        evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "                \n",
    "    # 定义加载Google训练好的Inception-v3模型的Saver。\n",
    "    load_fn = slim.assign_from_checkpoint_fn(\n",
    "      CKPT_FILE,\n",
    "      get_tuned_variables(),                          # 自定义函数，获取需要加载的参数\n",
    "      ignore_missing_vars=True)\n",
    "    \n",
    "    # 3. 建立会话，训练模型\n",
    "    saver = tf.train.Saver()      # 定义保存新模型的Saver。\n",
    "    with tf.Session() as sess:\n",
    "        # 初始化没有加载进来的变量。注意这个过程要在加载模型之前，\n",
    "        # 否则初始化过程会将已经加载好的变量重新赋值。\n",
    "        init = tf.global_variables_initializer()\n",
    "        sess.run(init)\n",
    "        \n",
    "        # 加载谷歌已经训练好的模型。\n",
    "        print('Loading tuned variables from %s' % CKPT_FILE)\n",
    "        load_fn(sess)\n",
    "            \n",
    "        start = 0\n",
    "        end = BATCH\n",
    "        for i in range(STEPS):\n",
    "            # 运行训练过程，这里不会更新全部参数，只会更新指定部分的参数\n",
    "            _, loss = sess.run([train_step, total_loss], feed_dict={\n",
    "                images: training_images[start:end], \n",
    "                labels: training_labels[start:end]})\n",
    "\n",
    "            # 保存并输出日志\n",
    "            if i % 30 == 0 or i + 1 == STEPS:\n",
    "                saver.save(sess, TRAIN_FILE, global_step=i)\n",
    "                validation_accuracy = sess.run(evaluation_step, feed_dict={\n",
    "                    images: validation_images, labels: validation_labels})\n",
    "                print('Step %d: Training loss is %.1f Validation accuracy = %.1f%%' % (\n",
    "                    i, loss, validation_accuracy * 100.0))\n",
    "                         \n",
    "            # 获取下一个batch\n",
    "            start = end\n",
    "            if start == n_training_example:\n",
    "                start = 0\n",
    "            end = start + BATCH\n",
    "            if end > n_training_example:\n",
    "                end = n_training_example\n",
    "            \n",
    "        # 在最后的测试数据上测试正确率。\n",
    "        test_accuracy = sess.run(evaluation_step, feed_dict={\n",
    "            images: testing_images, labels: testing_labels})\n",
    "        print('Final test accuracy = %.1f%%' % (test_accuracy * 100))\n",
    "        \n",
    "        \n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到，模型在新的数据集上能够很快收敛，并达到还不错的分类效果。（虽然我只取了1500/3670张图片，最后的accuracy也基本达到了书中全部数据的准确率91.9%，可见迁移学习的效果非常好！）"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
